{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BGE-M3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Installation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the required packages in your environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "%pip install -U transformers FlagEmbedding accelerate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. BGE-M3 structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModel\n",
    "import torch, os\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"BAAI/bge-m3\")\n",
    "raw_model = AutoModel.from_pretrained(\"BAAI/bge-m3\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The base model of BGE-M3 is [XLM-RoBERTa-large](https://huggingface.co/FacebookAI/xlm-roberta-large), which is a multilingual version of RoBERTa."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "XLMRobertaModel(\n",
       "  (embeddings): XLMRobertaEmbeddings(\n",
       "    (word_embeddings): Embedding(250002, 1024, padding_idx=1)\n",
       "    (position_embeddings): Embedding(8194, 1024, padding_idx=1)\n",
       "    (token_type_embeddings): Embedding(1, 1024)\n",
       "    (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "    (dropout): Dropout(p=0.1, inplace=False)\n",
       "  )\n",
       "  (encoder): XLMRobertaEncoder(\n",
       "    (layer): ModuleList(\n",
       "      (0-23): 24 x XLMRobertaLayer(\n",
       "        (attention): XLMRobertaAttention(\n",
       "          (self): XLMRobertaSelfAttention(\n",
       "            (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (output): XLMRobertaSelfOutput(\n",
       "            (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "            (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): XLMRobertaIntermediate(\n",
       "          (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "          (intermediate_act_fn): GELUActivation()\n",
       "        )\n",
       "        (output): XLMRobertaOutput(\n",
       "          (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "          (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (pooler): XLMRobertaPooler(\n",
       "    (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "    (activation): Tanh()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Multi-Functionality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 240131.91it/s]\n"
     ]
    }
   ],
   "source": [
    "from FlagEmbedding import BGEM3FlagModel\n",
    "\n",
    "model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)\n",
    "\n",
    "sentences_1 = [\"What is BGE M3?\", \"Defination of BM25\"]\n",
    "sentences_2 = [\"BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.\", \n",
    "               \"BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.1 Dense Retrieval"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Producing dense embeddings with BGE-M3 follows the same steps as with the BGE or BGE v1.5 models.\n",
    "\n",
    "Use the normalized hidden state of the special token [CLS] as the embedding:\n",
    "\n",
    "$$e_q = norm(H_q[0])$$\n",
    "\n",
    "Then compute the relevance score between the query and passage:\n",
    "\n",
    "$$s_{dense}=f_{sim}(e_p, e_q)$$\n",
    "\n",
    "where $e_p, e_q$ are the embedding vectors of passage and query, respectively.\n",
    "\n",
    "$f_{sim}$ is the scoring function (for example, inner product or L2 distance) used to compute the similarity between two embeddings."
   ]
  },
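  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a short worked relation (a standard identity for unit vectors, assuming both $e_p$ and $e_q$ are normalized as above): the inner product equals the cosine of the angle $\\theta$ between the two embeddings, and the squared L2 distance is a decreasing function of it, so either choice of $f_{sim}$ ranks passages in the same order:\n",
    "\n",
    "$$e_p\\cdot e_q = \\cos\\theta$$\n",
    "$$\\lVert e_p - e_q\\rVert^2 = \\lVert e_p\\rVert^2 + \\lVert e_q\\rVert^2 - 2 e_p\\cdot e_q = 2 - 2 e_p\\cdot e_q$$\n",
    "\n",
    "For example, an inner product of $0.6$ corresponds to a squared distance of $2 - 2\\times 0.6 = 0.8$."
   ]
  },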
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.6259035  0.34749585]\n",
      " [0.349868   0.6782462 ]]\n"
     ]
    }
   ],
   "source": [
    "# BGE-M3 accepts inputs of up to 8192 tokens; if you don't need that much, set max_length to a smaller value to speed up encoding.\n",
    "embeddings_1 = model.encode(sentences_1, max_length=10)['dense_vecs']\n",
    "embeddings_2 = model.encode(sentences_2, max_length=100)['dense_vecs']\n",
    "\n",
    "# compute the similarity scores\n",
    "s_dense = embeddings_1 @ embeddings_2.T\n",
    "print(s_dense)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 Sparse Retrieval"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set `return_sparse=True` to make the model return sparse vectors. If a token appears multiple times in a sentence, only its maximum weight is kept.\n",
    "\n",
    "BGE-M3 generates sparse embeddings by adding a linear layer and a ReLU activation function on top of the hidden states:\n",
    "\n",
    "$$w_{qt} = \\text{ReLU}(W_{lex}^T H_q [i])$$\n",
    "\n",
    "where $W_{lex}$ represents the weights of the linear layer and $H_q[i]$ is the encoder's output for the $i^{th}$ token."
   ]
  },
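  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The max-weight rule can be written out explicitly. As a sketch in the notation of this section (the symbol $q_i$ for the $i^{th}$ token of the query is introduced here only for illustration), the weight of a term $t$ is the largest activation over all positions where $t$ occurs:\n",
    "\n",
    "$$w_{qt} = \\max_{i: q_i = t} \\text{ReLU}(W_{lex}^T H_q[i])$$\n",
    "\n",
    "For instance, if a term occurred at two positions with activations $0.12$ and $0.25$ (made-up values, not model output), its weight would be $\\max(0.12, 0.25) = 0.25$."
   ]
  },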
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'What': 0.08362077, 'is': 0.081469566, 'B': 0.12964639, 'GE': 0.25186998, 'M': 0.17001738, '3': 0.26957875, '?': 0.040755156}, {'De': 0.050144322, 'fin': 0.13689369, 'ation': 0.045134712, 'of': 0.06342201, 'BM': 0.25167602, '25': 0.33353207}]\n"
     ]
    }
   ],
   "source": [
    "output_1 = model.encode(sentences_1, return_sparse=True)\n",
    "output_2 = model.encode(sentences_2, return_sparse=True)\n",
    "\n",
    "# you can see the weight for each token:\n",
    "print(model.convert_id_to_token(output_1['lexical_weights']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given the token weights of the query and the passage, their relevance score is computed from the joint importance of the terms that appear in both:\n",
    "\n",
    "$$s_{lex} = \\sum_{t\\in q\\cap p}(w_{qt} * w_{pt})$$\n",
    "\n",
    "where $w_{qt}$ and $w_{pt}$ are the importance weights of each shared term $t$ in the query and the passage, respectively."
   ]
  },
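  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small worked example with made-up weights (hypothetical values, not model output) shows how the sum behaves. Suppose the query has weights $w_{qa}=0.2$ and $w_{qb}=0.3$, and the passage has weights $w_{pb}=0.5$ and $w_{pc}=0.1$. Only term $b$ occurs on both sides, so:\n",
    "\n",
    "$$s_{lex} = w_{qb} * w_{pb} = 0.3\\times 0.5 = 0.15$$\n",
    "\n",
    "Terms $a$ and $c$ contribute nothing, because each appears on only one side."
   ]
  },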
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.19554448500275612\n",
      "0.00880391988903284\n"
     ]
    }
   ],
   "source": [
    "# compute the scores via lexical matching\n",
    "s_lex_10_20 = model.compute_lexical_matching_score(output_1['lexical_weights'][0], output_2['lexical_weights'][0])\n",
    "s_lex_10_21 = model.compute_lexical_matching_score(output_1['lexical_weights'][0], output_2['lexical_weights'][1])\n",
    "\n",
    "print(s_lex_10_20)\n",
    "print(s_lex_10_21)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3 Multi-Vector"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The multi-vector method utilizes the entire output embeddings for the representation of query $E_q$ and passage $E_p$.\n",
    "\n",
    "$$E_q = norm(W_{mul}^T H_q)$$\n",
    "$$E_p = norm(W_{mul}^T H_p)$$\n",
    "\n",
    "where $W_{mul}$ is the learnable projection matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(8, 1024)\n",
      "(30, 1024)\n"
     ]
    }
   ],
   "source": [
    "output_1 = model.encode(sentences_1, return_dense=True, return_sparse=True, return_colbert_vecs=True)\n",
    "output_2 = model.encode(sentences_2, return_dense=True, return_sparse=True, return_colbert_vecs=True)\n",
    "\n",
    "print(f\"({len(output_1['colbert_vecs'][0])}, {len(output_1['colbert_vecs'][0][0])})\")\n",
    "print(f\"({len(output_2['colbert_vecs'][0])}, {len(output_2['colbert_vecs'][0][0])})\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Following ColBERT, we use late interaction to compute the fine-grained relevance score:\n",
    "\n",
    "$$s_{mul}=\\frac{1}{N}\\sum_{i=1}^N\\max_{j=1}^M E_q[i]\\cdot E_p^T[j]$$\n",
    "\n",
    "where $E_q, E_p$ are the entire output embeddings of the query and the passage, respectively, and $N$ and $M$ are the numbers of vectors in $E_q$ and $E_p$.\n",
    "\n",
    "In other words, for each vector $v\\in E_q$ we take its maximum similarity with the vectors in $E_p$, then average these maxima over the query."
   ]
  },
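  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the late-interaction formula concrete, here is a toy calculation with made-up similarities (hypothetical values, not model output). Suppose the query has $N=2$ vectors and the passage has $M=3$, with inner products $(0.9, 0.2, 0.1)$ between $E_q[1]$ and the three passage vectors, and $(0.3, 0.4, 0.7)$ between $E_q[2]$ and the same vectors. Taking the maximum per query vector and then averaging gives:\n",
    "\n",
    "$$s_{mul} = \\frac{1}{2}\\left(\\max(0.9, 0.2, 0.1) + \\max(0.3, 0.4, 0.7)\\right) = \\frac{0.9 + 0.7}{2} = 0.8$$\n",
    "\n",
    "Each query vector is matched to its single best passage vector, so the score does not depend on the order of the passage vectors."
   ]
  },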
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7796662449836731\n",
      "0.4621177911758423\n"
     ]
    }
   ],
   "source": [
    "s_mul_10_20 = model.colbert_score(output_1['colbert_vecs'][0], output_2['colbert_vecs'][0]).item()\n",
    "s_mul_10_21 = model.colbert_score(output_1['colbert_vecs'][0], output_2['colbert_vecs'][1]).item()\n",
    "\n",
    "print(s_mul_10_20)\n",
    "print(s_mul_10_21)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.4 Hybrid Ranking"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "BGE-M3's multi-functionality makes hybrid ranking possible, which can improve retrieval. Because the multi-vector method is expensive, we can first retrieve candidates with either the dense or the sparse method. Then, to get the final result, we rerank the candidates by the integrated relevance score:\n",
    "\n",
    "$$s_{rank} = w_1\\cdot s_{dense}+w_2\\cdot s_{lex} + w_3\\cdot s_{mul}$$\n",
    "\n",
    "where the values chosen for $w_1, w_2$ and $w_3$ vary with the downstream scenario (the value 1/3 used here is just for demonstration)."
   ]
  },
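  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked check of the formula, take equal weights $w_1=w_2=w_3=1/3$ and the scores printed in the previous sections, rounded to four decimals. For the first query and the first passage ($s_{dense}\\approx 0.6259$, $s_{lex}\\approx 0.1955$, $s_{mul}\\approx 0.7797$):\n",
    "\n",
    "$$s_{rank}\\approx\\frac{0.6259 + 0.1955 + 0.7797}{3} = \\frac{1.6011}{3}\\approx 0.5337$$\n",
    "\n",
    "For the first query and the second passage ($s_{dense}\\approx 0.3475$, $s_{lex}\\approx 0.0088$, $s_{mul}\\approx 0.4621$):\n",
    "\n",
    "$$s_{rank}\\approx\\frac{0.3475 + 0.0088 + 0.4621}{3} = \\frac{0.8184}{3}\\approx 0.2728$$\n",
    "\n",
    "With equal weights the matching passage scores roughly twice as high as the non-matching one."
   ]
  },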
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5337047390639782\n",
      "0.27280585498859483\n"
     ]
    }
   ],
   "source": [
    "s_rank_10_20 = 1/3 * s_dense[0][0] + 1/3 * s_lex_10_20 + 1/3 * s_mul_10_20\n",
    "s_rank_10_21 = 1/3 * s_dense[0][1] + 1/3 * s_lex_10_21 + 1/3 * s_mul_10_21\n",
    "\n",
    "print(s_rank_10_20)\n",
    "print(s_rank_10_21)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
